{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Example Masks\n",
    "\n",
    "In this notebook, we plot the [Grad-CAM](https://arxiv.org/abs/1610.02391) figures from the paper. For more info and other examples, have a look at [our README](https://github.com/ecs-vlc/fmix).\n",
    "\n",
    "**Note**: The easiest way to use this is as a colab notebook, which allows you to dive in with no setup.\n",
    "\n",
    "First, we load dependencies and some data from CIFAR-10:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision\n",
    "from torchvision import transforms\n",
    "import models\n",
    "from torchbearer import Trial\n",
    "import cv2\n",
    "import torch.nn.functional as F\n",
    "\n",
    "inv_norm = transforms.Normalize((-0.4914/0.2023, -0.4822/0.1994, -0.4465/0.2010), (1/0.2023, 1/0.1994, 1/0.2010))\n",
    "valset = torchvision.datasets.CIFAR10(root='./data/cifar', train=False, download=True,\n",
    "                                           transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465),\n",
    "                             (0.2023, 0.1994, 0.2010)),]))\n",
    "\n",
    "valloader = torch.utils.data.DataLoader(valset, batch_size=1, shuffle=True, num_workers=8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Wrapper\n",
    "\n",
    "Next, we define a `ResNet` wrapper that will iterate up to some given block `k`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResNet_CAM(nn.Module):\n",
    "    def __init__(self, net, layer_k):\n",
    "        super(ResNet_CAM, self).__init__()\n",
    "        self.resnet = net\n",
    "        convs = nn.Sequential(*list(net.children())[:-1])\n",
    "        self.first_part_conv = convs[:layer_k]\n",
    "        self.second_part_conv = convs[layer_k:]\n",
    "        self.linear = nn.Sequential(*list(net.children())[-1:])\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.first_part_conv(x)\n",
    "        x.register_hook(self.activations_hook)\n",
    "        x = self.second_part_conv(x)\n",
    "        x = F.adaptive_avg_pool2d(x, (1,1))\n",
    "        x = x.view((1, -1))\n",
    "        x = self.linear(x)\n",
    "        return x\n",
    "    \n",
    "    def activations_hook(self, grad):\n",
    "        self.gradients = grad\n",
    "    \n",
    "    def get_activations_gradient(self):\n",
    "        return self.gradients\n",
    "    \n",
    "    def get_activations(self, x):\n",
    "        return self.first_part_conv(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Grad-CAM\n",
    "\n",
    "Now for the Grad-CAM code, adapted from [implementing-grad-cam-in-pytorch](https://medium.com/@stepanulyanin/implementing-grad-cam-in-pytorch-ea0937c31e82):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def superimpose_heatmap(heatmap, img):\n",
    "    resized_heatmap = cv2.resize(heatmap.numpy(), (img.shape[2], img.shape[3]))\n",
    "    resized_heatmap = np.uint8(255 * resized_heatmap)\n",
    "    resized_heatmap = cv2.applyColorMap(resized_heatmap, cv2.COLORMAP_JET)\n",
    "    superimposed_img = torch.Tensor(cv2.cvtColor(resized_heatmap, cv2.COLOR_BGR2RGB)) * 0.006 + inv_norm(img[0]).permute(1,2,0)\n",
    "    \n",
    "    return superimposed_img\n",
    "\n",
    "def get_grad_cam(net, img):\n",
    "    net.eval()\n",
    "    pred = net(img)\n",
    "    pred[:,pred.argmax(dim=1)].backward()\n",
    "    gradients = net.get_activations_gradient()\n",
    "    pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])\n",
    "    activations = net.get_activations(img).detach()\n",
    "    for i in range(activations.size(1)):\n",
    "        activations[:, i, :, :] *= pooled_gradients[i]\n",
    "    heatmap = torch.mean(activations, dim=1).squeeze()\n",
    "    heatmap = np.maximum(heatmap, 0)\n",
    "    heatmap /= torch.max(heatmap)\n",
    "    \n",
    "    return torch.Tensor(superimpose_heatmap(heatmap, img).permute(2,0,1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Models From `torch.hub`\n",
    "\n",
    "Next, we load in the models from `torch.hub`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading: \"https://github.com/ecs-vlc/FMix/archive/master.zip\" to /home/ethan/.cache/torch/hub/master.zip\n",
      "Using cache found in /home/ethan/.cache/torch/hub/ecs-vlc_FMix_master\n",
      "Using cache found in /home/ethan/.cache/torch/hub/ecs-vlc_FMix_master\n",
      "Using cache found in /home/ethan/.cache/torch/hub/ecs-vlc_FMix_master\n"
     ]
    }
   ],
   "source": [
    "baseline_net = torch.hub.load('ecs-vlc/FMix:master', 'preact_resnet18_cifar10_baseline', pretrained=True)\n",
    "\n",
    "fmix_net = torch.hub.load('ecs-vlc/FMix:master', 'preact_resnet18_cifar10_fmix', pretrained=True)\n",
    "\n",
    "mixup_net = torch.hub.load('ecs-vlc/FMix:master', 'preact_resnet18_cifar10_mixup', pretrained=True)\n",
    "\n",
    "fmix_plus_net = torch.hub.load('ecs-vlc/FMix:master', 'preact_resnet18_cifar10_fmixplusmixup', pretrained=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plots\n",
    "\n",
    "Finally, generate and save the Grad-CAM plots:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "layer_k = 4\n",
    "n_imgs = 10\n",
    "\n",
    "baseline_cam_net = ResNet_CAM(baseline_net, layer_k)\n",
    "fmix_cam_net = ResNet_CAM(fmix_net, layer_k)\n",
    "mixup_cam_net = ResNet_CAM(mixup_net, layer_k)\n",
    "fmix_plus_cam_net = ResNet_CAM(fmix_plus_net, layer_k)\n",
    "\n",
    "imgs = torch.Tensor(5, n_imgs, 3, 32, 32)\n",
    "it = iter(valloader)\n",
    "for i in range(0,n_imgs):\n",
    "    img, _ = next(it)\n",
    "    imgs[0][i] = inv_norm(img[0])\n",
    "    imgs[1][i] = get_grad_cam(baseline_cam_net, img)\n",
    "    imgs[2][i] = get_grad_cam(mixup_cam_net, img)\n",
    "    imgs[3][i] = get_grad_cam(fmix_cam_net, img)\n",
    "    imgs[4][i] = get_grad_cam(fmix_plus_cam_net, img)\n",
    "\n",
    "torchvision.utils.save_image(imgs.view(-1, 3, 32, 32), \"gradcam_at_layer\" + str(layer_k) + \".png\",nrow=n_imgs, pad_value=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
